# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sympy as sm
from hysop.symbolic import Expr
from hysop.tools.htypes import first_not_None
from hysop.tools.numpywrappers import npw
[docs]
class NAryRelation(Expr):
    """
    Represents relations between n variables.

    A relation is printed by rendering every argument, joining the pieces
    with the infix operator returned by ``rel_op`` (surrounded by single
    spaces) and wrapping the whole in parentheses, e.g. ``(a + b + c)``.
    The same layout is used by ``str()``, the sympy string printer and the
    C printer.

    Parameters
    ----------
    args: tuple of Expr
    """

    @property
    def rel_op(self):
        # Concrete subclasses override this; the generic relation has none.
        return None

    def __new__(cls, *exprs):
        return super().__new__(cls, *exprs)

    def _join_args(self, render):
        # Common formatter for all printers: `render` turns one argument
        # into text, the results are joined with " <rel_op> " and wrapped.
        separator = f" {self.rel_op} "
        pieces = [f"{render(arg)}" for arg in self.args]
        return "(" + separator.join(pieces) + ")"

    def __str__(self):
        return self._join_args(str)

    def _sympystr(self, printer):
        return self._join_args(printer._print)

    def _ccode(self, printer):
        return self._join_args(printer._print)

    @property
    def is_number(self):
        # Always reported as a number.
        # NOTE(review): presumably so that sympy treats the relation as an
        # opaque constant — confirm this is the intent.
        return True

    @property
    def free_symbols(self):
        # NOTE(review): sympy's Basic.free_symbols returns a set; this returns
        # an empty tuple — confirm no caller relies on set operations.
        return ()
[docs]
class LogicalRelation(NAryRelation):
    """Common base of the logical and comparison relations; adds no behavior."""
[docs]
class ArithmeticRelation(NAryRelation):
    """Common base of the arithmetic relations (Add, Mul, Pow); adds no behavior."""
[docs]
class Add(ArithmeticRelation):
    """Sum of the arguments, printed as ``(a + b + ...)``."""

    rel_op = property(
        lambda self: "+", doc="Infix operator placed between the arguments."
    )
[docs]
class Mul(ArithmeticRelation):
    """Product of the arguments, printed as ``(a * b * ...)``."""

    rel_op = property(
        lambda self: "*", doc="Infix operator placed between the arguments."
    )
[docs]
class Pow(ArithmeticRelation):
    """
    Exponentiation of the arguments, ``a ** b ** ...``.

    ``str()`` and the sympy string printer use the infix ``**`` operator
    inherited from :class:`NAryRelation`. C (and OpenCL C) has no such
    operator: ``a ** b`` would parse as ``a * (*b)``. The C printer therefore
    emits nested ``pow`` calls instead, grouped from the right like Python's
    ``**``. For example, arguments ``(a, b, c)`` give ``pow(a, pow(b, c))``.
    """

    @property
    def rel_op(self):
        return "**"

    def _ccode(self, printer):
        rendered = [f"{printer._print(x)}" for x in self.args]
        if len(rendered) < 2:
            # Nothing to exponentiate: keep the generic "(...)" rendering.
            return "({})".format("".join(rendered))
        # Fold from the right: ** is right-associative.
        result = rendered[-1]
        for base in reversed(rendered[:-1]):
            result = f"pow({base}, {result})"
        return result
[docs]
class LogicalAND(LogicalRelation):
    """Logical conjunction, printed as ``(a && b && ...)``."""

    rel_op = property(
        lambda self: "&&", doc="Infix operator placed between the arguments."
    )
[docs]
class LogicalOR(LogicalRelation):
    """Logical disjunction, printed as ``(a || b || ...)``."""

    rel_op = property(
        lambda self: "||", doc="Infix operator placed between the arguments."
    )
[docs]
class LogicalXOR(LogicalRelation):
    """
    Exclusive or, printed as ``(a ^ b ^ ...)``.

    In C ``^`` is the bitwise XOR operator. It matches logical XOR only for
    operands that are 0 or 1.
    """

    rel_op = property(
        lambda self: "^", doc="Infix operator placed between the arguments."
    )
[docs]
class LogicalEQ(LogicalRelation):
    """Equality test, printed as ``(a == b)``."""

    rel_op = property(
        lambda self: "==", doc="Infix operator placed between the arguments."
    )
[docs]
class LogicalNE(LogicalRelation):
    """Inequality test, printed as ``(a != b)``."""

    rel_op = property(
        lambda self: "!=", doc="Infix operator placed between the arguments."
    )
[docs]
class LogicalLT(LogicalRelation):
    """Strict less-than comparison, printed as ``(a < b)``."""

    rel_op = property(
        lambda self: "<", doc="Infix operator placed between the arguments."
    )
[docs]
class LogicalGT(LogicalRelation):
    """Strict greater-than comparison, printed as ``(a > b)``."""

    rel_op = property(
        lambda self: ">", doc="Infix operator placed between the arguments."
    )
[docs]
class LogicalLE(LogicalRelation):
    """Less-than-or-equal comparison, printed as ``(a <= b)``."""

    rel_op = property(
        lambda self: "<=", doc="Infix operator placed between the arguments."
    )
[docs]
class LogicalGE(LogicalRelation):
    """Greater-than-or-equal comparison, printed as ``(a >= b)``."""

    rel_op = property(
        lambda self: ">=", doc="Infix operator placed between the arguments."
    )
[docs]
class BinaryRelation(NAryRelation):
    """
    Represents relations between 2 variables.

    The two operands become the relation's only arguments. They are also
    stored on the instance as the ``lhs`` and ``rhs`` attributes.

    Parameters
    ----------
    lhs : Expr
    rhs : Expr
    """

    def __new__(cls, lhs, rhs):
        instance = super().__new__(cls, lhs, rhs)
        # Keep direct handles on both operands (assigned lhs first, then rhs).
        instance.lhs, instance.rhs = lhs, rhs
        return instance
[docs]
class Assignment(BinaryRelation):
    """
    Represents variable assignment for code generation.

    ``str()`` renders ``<lhs> <op> <rhs>;``. A plain ``=`` operator is shown
    as ``:=``; any other operator (as defined by subclasses) is shown
    unchanged. The C printer first asks ``lhs`` to produce its own
    declaration initialised with ``rhs``. If that fails for any reason, it
    falls back to a plain ``<lhs> <op> <rhs>;`` statement.

    Parameters
    ----------
    lhs : Expr
    rhs : Expr
    """

    def __str__(self):
        # Use the target's `name` attribute when it has a non-None one.
        lhs = first_not_None(getattr(self.lhs, "name", None), self.lhs)
        rel_op = self.rel_op
        if rel_op == "=":
            rel_op = ":" + rel_op
        return "{} {} {};".format(
            lhs, rel_op, sm.printing.str.StrPrinter()._print(self.rhs)
        )

    def _ccode(self, printer):
        # NOTE(review): the declaration path ignores self.rel_op, so an
        # augmented assignment whose lhs provides `declare` is emitted as a
        # declaration initialised with rhs — confirm this is intended.
        try:
            return self.lhs.declare(init=printer._print(self.rhs))
        except Exception:
            # lhs cannot declare itself (e.g. it has no `declare` method).
            # Emit a plain statement instead. Catching Exception rather than
            # using a bare except lets KeyboardInterrupt/SystemExit propagate.
            return "{} {} {};".format(
                printer._print(self.lhs), self.rel_op, printer._print(self.rhs)
            )

    @property
    def rel_op(self):
        return "="

    @classmethod
    def assign(cls, lhs, rhs, skip_zero_rhs=False):
        """
        Build a tuple of assignments from scalar or array operands.

        Parameters
        ----------
        lhs, rhs : sympy.Basic or npw.ndarray
            Assignment targets and values.

            * Two arrays must have the same shape. They are paired
              element-wise in ``ravel`` order.
            * An array paired with any other object repeats that object for
              every array element.
            * Two ``sympy.Basic`` objects yield a single assignment.
        skip_zero_rhs : bool, optional
            When True, pairs whose right-hand side compares equal to 0 are
            dropped.

        Returns
        -------
        tuple of cls
            One ``cls(l, r)`` instance per retained ``(l, r)`` pair, in order.

        Raises
        ------
        TypeError
            If the operand types are not handled.
        AssertionError
            If two arrays of different shapes are given.
        """

        def keep(value):
            # Decides whether an assignment with this right-hand side is built.
            return (not skip_zero_rhs) or (value != 0)

        lhs_is_array = isinstance(lhs, npw.ndarray)
        rhs_is_array = isinstance(rhs, npw.ndarray)
        if lhs_is_array and rhs_is_array:
            # Equal shapes imply equal sizes, so one check covers both.
            assert rhs.shape == lhs.shape, (lhs.shape, rhs.shape)
            targets = lhs.ravel().tolist()
            values = rhs.ravel().tolist()
        elif lhs_is_array:
            targets = lhs.ravel().tolist()
            values = (rhs,) * len(targets)
        elif rhs_is_array:
            values = rhs.ravel().tolist()
            targets = (lhs,) * len(values)
        elif isinstance(lhs, sm.Basic) and isinstance(rhs, sm.Basic):
            targets, values = (lhs,), (rhs,)
        else:
            msg = "Cannot handle operand types:\n *lhs: {}\n *rhs: {}\n"
            msg = msg.format(type(lhs), type(rhs))
            raise TypeError(msg)
        # Filter before constructing, and build the tuple in one pass.
        return tuple(cls(t, v) for t, v in zip(targets, values) if keep(v))
[docs]
class AugmentedAssignment(Assignment):
    """
    Base class for augmented assignments.

    Concrete subclasses set a ``_symbol`` class attribute (e.g. ``"+"``).
    The printed operator is that symbol followed by ``=`` (e.g. ``+=``).
    This base class does not define ``_symbol`` itself.
    """

    @property
    def rel_op(self):
        symbol = self._symbol
        return symbol + "="
[docs]
class AddAugmentedAssignment(AugmentedAssignment):
    """In-place addition, printed with the ``+=`` operator."""

    _symbol = "+"
[docs]
class SubAugmentedAssignment(AugmentedAssignment):
    """In-place subtraction, printed with the ``-=`` operator."""

    _symbol = "-"
[docs]
class MulAugmentedAssignment(AugmentedAssignment):
    """In-place multiplication, printed with the ``*=`` operator."""

    _symbol = "*"
[docs]
class DivAugmentedAssignment(AugmentedAssignment):
    """In-place division, printed with the ``/=`` operator."""

    _symbol = "/"
[docs]
class ModAugmentedAssignment(AugmentedAssignment):
    """In-place remainder, printed with the ``%=`` operator."""

    _symbol = "%"
[docs]
class NAryFunction(Expr):
    """
    Represents a call to a named function taking n arguments.

    The call is rendered as ``fname(arg0, arg1, ...)`` by ``str()``, the
    sympy string printer and the C printer alike. Subclasses must provide
    the ``fname`` property.

    Parameters
    ----------
    args: tuple of Expr
    """

    @property
    def fname(self):
        """Name of the function to emit; must be overridden by subclasses."""
        # `raise NotImplemented` would itself fail with a TypeError, because
        # NotImplemented is not an exception. Raise the intended error.
        raise NotImplementedError(
            f"{type(self).__name__} must define the 'fname' property."
        )

    def __new__(cls, *exprs):
        return super().__new__(cls, *exprs)

    def _format_call(self, render):
        # Shared by all printers: `render` turns one argument into text.
        # fname is read first, as in the original per-printer formatting.
        name = self.fname
        arguments = ", ".join(f"{render(x)}" for x in self.args)
        return f"{name}({arguments})"

    def __str__(self):
        return self._format_call(str)

    def _sympystr(self, printer):
        return self._format_call(printer._print)

    def _ccode(self, printer):
        return self._format_call(printer._print)

    @property
    def is_number(self):
        # Always reported as a number.
        # NOTE(review): presumably so that sympy treats the call as an opaque
        # constant — confirm this is the intent.
        return True

    @property
    def free_symbols(self):
        # NOTE(review): sympy's Basic.free_symbols returns a set; this returns
        # an empty tuple — confirm no caller relies on set operations.
        return ()
[docs]
class UnaryFunction(NAryFunction):
    """Function call with exactly one argument, rendered as ``fname(a)``."""

    def __new__(cls, a):
        # The single positional parameter pins the arity to one.
        instance = super().__new__(cls, a)
        return instance
[docs]
class BinaryFunction(NAryFunction):
    """Function call with exactly two arguments, rendered as ``fname(lhs, rhs)``."""

    def __new__(cls, lhs, rhs):
        # Two positional parameters pin the arity to two.
        instance = super().__new__(cls, lhs, rhs)
        return instance
[docs]
class Max(BinaryFunction):
    """Larger of two values, rendered as ``max(a, b)``."""

    fname = property(lambda self: "max", doc="Name of the emitted function.")
[docs]
class Min(BinaryFunction):
    """Smaller of two values, rendered as ``min(a, b)``."""

    fname = property(lambda self: "min", doc="Name of the emitted function.")
[docs]
class Round(UnaryFunction):
    """Rounding of a single value, rendered as ``round(a)``."""

    fname = property(lambda self: "round", doc="Name of the emitted function.")